{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Generating Protein Sequences with ProtGPT2 Locally\n",
        "This notebook is a companion to Chapter 7 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "The code in this notebook generates protein sequences using the [ProtGPT2](https://huggingface.co/nferruz/ProtGPT2) model. It doesn't require hardware acceleration.  \n",
        "More details about the code can be found in the corresponding chapter of the book."
      ],
      "metadata": {
        "id": "U7iuQrsVxLk0"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the ProtGPT2 model from the Hugging Face Hub and set up a text-generation inference pipeline for it."
      ],
      "metadata": {
        "id": "Deddq7EKyCnA"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A0HtDGt3-2f9"
      },
      "outputs": [],
      "source": [
        "from transformers import pipeline\n",
        "\n",
        "model_id = \"nferruz/ProtGPT2\"\n",
        "protgpt2 = pipeline('text-generation', model=model_id)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Use the pipeline to generate protein sequences (10 in this example). Once generation completes, the sequences are printed to standard output."
      ],
      "metadata": {
        "id": "aFfd2MrsyXit"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "sequences = protgpt2(\"<|endoftext|>\", max_length=100, do_sample=True, top_k=950,\n",
        "                     repetition_penalty=1.2, num_return_sequences=10,\n",
        "                     eos_token_id=0)\n",
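        "for seq in sequences:\n",
        "    # Each result is a dict; the sequence itself is stored under 'generated_text'\n",
        "    print(seq['generated_text'])"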
      ],
      "metadata": {
        "id": "ArBLDM7c_KcW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
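        "Define a function that calculates the perplexity of a generated sequence. Perplexity is the exponential of the mean per-token negative log-likelihood under the model, so lower values indicate sequences that the model considers more likely."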
      ],
      "metadata": {
        "id": "uxm4McgDzOHp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
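        "def calculate_perplexity(model, tokenizer, text, device):\n",
        "    \"\"\"Return the perplexity of `text` under `model` as a scalar tensor.\"\"\"\n",
        "    encodings = tokenizer(text, return_tensors='pt').to(device)\n",
        "\n",
        "    input_ids = encodings.input_ids\n",
        "    # For causal LMs the labels are the inputs; the model shifts them internally\n",
        "    target_ids = input_ids.clone()\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = model(input_ids, labels=target_ids)\n",
        "\n",
        "    # outputs.loss is the mean cross-entropy (negative log-likelihood) per token\n",
        "    neg_log_likelihood = outputs.loss\n",
        "\n",
        "    perplexity = torch.exp(neg_log_likelihood)\n",
        "\n",
        "    return perplexity"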
      ],
      "metadata": {
        "id": "nvNTt0iIJQ2E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Evaluate the generated sequences by calculating the perplexity of each one."
      ],
      "metadata": {
        "id": "wSRkSa5rAg3Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "device = 'cpu'\n",
        "for seq in sequences:\n",
        "  print(calculate_perplexity(protgpt2.model, protgpt2.tokenizer,\n",
        "                       seq['generated_text'], device))"
      ],
      "metadata": {
        "id": "tCDvGIGHJSC4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Alternatively, we can calculate perplexity for a whole batch of generated protein sequences in a single forward pass. Let's define a custom function for this."
      ],
      "metadata": {
        "id": "WRrVc4CZ3goN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Reuse the EOS token as the padding token so that sequences can be batched\n",
        "protgpt2.tokenizer.pad_token = protgpt2.tokenizer.eos_token\n",
        "\n",
        "def calculate_batch_perplexity(input_texts, model, tokenizer):\n",
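        "    \"\"\"\n",
        "    Calculate perplexity for a batch of input texts using a pretrained language model.\n",
        "\n",
        "    Args:\n",
        "    - input_texts (List[str]): A list of input texts to evaluate.\n",
        "    - model: A pretrained causal language model.\n",
        "    - tokenizer: The model's tokenizer, with a padding token set.\n",
        "\n",
        "    Returns:\n",
        "    - dict: \"perplexities\" (a tensor with one score per input text) and\n",
        "      \"mean_perplexity\" (a scalar tensor with their mean).\n",
        "    \"\"\"\n",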
        "    # Tokenize the batch of texts with padding for uniform length\n",
        "    inputs = tokenizer(\n",
        "        input_texts, return_tensors=\"pt\", padding=True, truncation=True\n",
        "    )\n",
        "\n",
        "    input_ids = inputs[\"input_ids\"]\n",
        "    attention_mask = inputs[\"attention_mask\"]\n",
        "\n",
        "    # Pass the input batch through the model to get logits\n",
        "    with torch.no_grad():\n",
        "        outputs = model(input_ids, attention_mask=attention_mask)\n",
        "        logits = outputs.logits\n",
        "\n",
        "    # Shift the logits and input_ids to align targets correctly\n",
        "    # Logits dimensions are: (batch_size, seq_length, vocab_size)\n",
        "    shift_logits = logits[:, :-1, :]  # Ignore the last token's logits\n",
        "    shift_labels = input_ids[:, 1:]   # Skip the first token in the labels\n",
        "\n",
        "    # Compute log probabilities\n",
        "    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)\n",
        "\n",
        "    # Gather the log probabilities for the correct tokens\n",
        "    target_log_probs = log_probs.gather(dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)\n",
        "\n",
        "    # Mask out positions corresponding to padding tokens\n",
        "    target_log_probs = target_log_probs * attention_mask[:, 1:].to(log_probs.dtype)\n",
        "\n",
        "    # Compute the mean negative log-likelihood for each sequence\n",
        "    negative_log_likelihood = -target_log_probs.sum(dim=-1) / attention_mask[:, 1:].sum(dim=-1)\n",
        "\n",
        "    # Compute perplexity for each sequence\n",
        "    perplexities = torch.exp(negative_log_likelihood)\n",
        "\n",
        "    # Average the per-sequence perplexities across the batch\n",
        "    mean_perplexity_score = torch.mean(perplexities)\n",
        "\n",
        "    return {\"perplexities\": perplexities, \"mean_perplexity\": mean_perplexity_score}"
      ],
      "metadata": {
        "id": "33tjAT_5-3a5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Execute the `calculate_batch_perplexity` function on the generated protein sequences."
      ],
      "metadata": {
        "id": "5dRGmtGy3spU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "sequence_texts = [seq['generated_text'] for seq in sequences]\n",
        "print(f\"Perplexity scores: {calculate_batch_perplexity(sequence_texts, protgpt2.model, protgpt2.tokenizer)}\")"
      ],
      "metadata": {
        "id": "cE9FtGCk_GIl"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}